import torch

# --- Gradient of y = 2 * (x . x) ------------------------------------------
x = torch.arange(4.0)
print(x)
# Ask autograd to record operations on x so backward() can populate x.grad.
# Equivalent to: x = torch.arange(4.0, requires_grad=True)
x.requires_grad_(True)
# No backward pass has run yet, so the gradient is still unset (None).
# This is checked explicitly because a bare `x.grad` expression has no effect
# when run as a script rather than in a notebook.
assert x.grad is None
y = 2 * torch.dot(x, x)
print(y)
y.backward()
# Analytically, d/dx [2 * (x . x)] = 4x.
print(x.grad)
print(x.grad == 4 * x)

# --- Gradient of y = sum(x) -----------------------------------------------
# backward() accumulates into .grad rather than overwriting it, so the
# gradient from the previous computation must be cleared first.
x.grad.zero_()
y = x.sum()
print(y)
y.backward()
# Analytically, d/dx sum(x) is a vector of ones.
print(x.grad)

# --- Matrix-matrix product -------------------------------------------------
# X holds 0..19 laid out as a 5x4 matrix; B holds 0..3 as a 4x1 column.
X = torch.arange(20).reshape(5, 4)
B = torch.arange(4).reshape(-1, 1)

# (5x4) @ (4x1) -> (5x1), computed with the functional form of Tensor.mm.
print(torch.mm(X, B))